#  Copyright 2022-present, the Waterdip Labs Pvt. Ltd.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.


from dcs_core.core.common.models.metric import MetricsType
from dcs_core.core.datasource.search_datasource import SearchIndexDataSource
from dcs_core.core.datasource.sql_datasource import SQLDataSource
from dcs_core.core.metric.base import FieldMetrics, MetricIdentity


class MinMetric(FieldMetrics):
    """
    Field metric whose value comes from the data source's ``query_get_min``.

    SQL data sources are queried by table name and search-index data sources
    by index name; any other data source type is rejected.
    """

    def get_metric_identity(self):
        """Build the identity of this metric, typed as ``MetricsType.MIN``.

        Falsy table/index names are passed on as ``None``.
        """
        return MetricIdentity.generate_identity(
            metric_type=MetricsType.MIN,
            metric_name=self.name,
            data_source=self.data_source,
            field_name=self.field_name,
            table_name=self.table_name or None,
            index_name=self.index_name or None,
        )

    def _generate_metric_value(self):
        """Run ``query_get_min`` against the configured data source.

        Raises:
            ValueError: if the data source is neither a ``SQLDataSource``
                nor a ``SearchIndexDataSource``.
        """
        source = self.data_source
        # SQL sources are addressed by table, search sources by index.
        if isinstance(source, SQLDataSource):
            target = {"table": self.table_name}
        elif isinstance(source, SearchIndexDataSource):
            target = {"index_name": self.index_name}
        else:
            raise ValueError("Invalid data source type")
        return source.query_get_min(
            field=self.field_name,
            filters=self.filter_query or None,  # falsy filter is sent as None
            **target,
        )


class MaxMetric(FieldMetrics):
    """
    Field metric whose value comes from the data source's ``query_get_max``.

    SQL data sources are queried by table name and search-index data sources
    by index name; any other data source type is rejected.
    """

    def get_metric_identity(self):
        """Build the identity of this metric, typed as ``MetricsType.MAX``.

        Falsy table/index names are passed on as ``None``.
        """
        return MetricIdentity.generate_identity(
            metric_type=MetricsType.MAX,
            metric_name=self.name,
            data_source=self.data_source,
            field_name=self.field_name,
            table_name=self.table_name or None,
            index_name=self.index_name or None,
        )

    def _generate_metric_value(self):
        """Run ``query_get_max`` against the configured data source.

        Raises:
            ValueError: if the data source is neither a ``SQLDataSource``
                nor a ``SearchIndexDataSource``.
        """
        source = self.data_source
        # SQL sources are addressed by table, search sources by index.
        if isinstance(source, SQLDataSource):
            target = {"table": self.table_name}
        elif isinstance(source, SearchIndexDataSource):
            target = {"index_name": self.index_name}
        else:
            raise ValueError("Invalid data source type")
        return source.query_get_max(
            field=self.field_name,
            filters=self.filter_query or None,  # falsy filter is sent as None
            **target,
        )


class AvgMetric(FieldMetrics):
    """
    Field metric whose value comes from the data source's ``query_get_avg``.

    SQL data sources are queried by table name and search-index data sources
    by index name; any other data source type is rejected.
    """

    def get_metric_identity(self):
        """Build the identity of this metric, typed as ``MetricsType.AVG``.

        Falsy table/index names are passed on as ``None``.
        """
        return MetricIdentity.generate_identity(
            metric_type=MetricsType.AVG,
            metric_name=self.name,
            data_source=self.data_source,
            field_name=self.field_name,
            table_name=self.table_name or None,
            index_name=self.index_name or None,
        )

    def _generate_metric_value(self):
        """Run ``query_get_avg`` against the configured data source.

        Raises:
            ValueError: if the data source is neither a ``SQLDataSource``
                nor a ``SearchIndexDataSource``.
        """
        source = self.data_source
        # SQL sources are addressed by table, search sources by index.
        if isinstance(source, SQLDataSource):
            target = {"table": self.table_name}
        elif isinstance(source, SearchIndexDataSource):
            target = {"index_name": self.index_name}
        else:
            raise ValueError("Invalid data source type")
        return source.query_get_avg(
            field=self.field_name,
            filters=self.filter_query or None,  # falsy filter is sent as None
            **target,
        )


class SumMetric(FieldMetrics):
    """
    Field metric whose value comes from the data source's ``query_get_sum``.

    SQL data sources are queried by table name and search-index data sources
    by index name; any other data source type is rejected.
    """

    def get_metric_identity(self):
        """Build the identity of this metric, typed as ``MetricsType.SUM``.

        Falsy table/index names are passed on as ``None``.
        """
        return MetricIdentity.generate_identity(
            metric_type=MetricsType.SUM,
            metric_name=self.name,
            data_source=self.data_source,
            field_name=self.field_name,
            table_name=self.table_name or None,
            index_name=self.index_name or None,
        )

    def _generate_metric_value(self):
        """Run ``query_get_sum`` against the configured data source.

        Raises:
            ValueError: if the data source is neither a ``SQLDataSource``
                nor a ``SearchIndexDataSource``.
        """
        source = self.data_source
        # SQL sources are addressed by table, search sources by index.
        if isinstance(source, SQLDataSource):
            target = {"table": self.table_name}
        elif isinstance(source, SearchIndexDataSource):
            target = {"index_name": self.index_name}
        else:
            raise ValueError("Invalid data source type")
        return source.query_get_sum(
            field=self.field_name,
            filters=self.filter_query or None,  # falsy filter is sent as None
            **target,
        )


class VarianceMetric(FieldMetrics):
    """
    Field metric whose value comes from the data source's ``query_get_variance``.

    SQL data sources are queried by table name and search-index data sources
    by index name; any other data source type is rejected.
    """

    def get_metric_identity(self):
        """Build the identity of this metric, typed as ``MetricsType.VARIANCE``.

        Falsy table/index names are passed on as ``None``.
        """
        return MetricIdentity.generate_identity(
            metric_type=MetricsType.VARIANCE,
            metric_name=self.name,
            data_source=self.data_source,
            field_name=self.field_name,
            table_name=self.table_name or None,
            index_name=self.index_name or None,
        )

    def _generate_metric_value(self):
        """Run ``query_get_variance`` against the configured data source.

        Raises:
            ValueError: if the data source is neither a ``SQLDataSource``
                nor a ``SearchIndexDataSource``.
        """
        source = self.data_source
        # SQL sources are addressed by table, search sources by index.
        if isinstance(source, SQLDataSource):
            target = {"table": self.table_name}
        elif isinstance(source, SearchIndexDataSource):
            target = {"index_name": self.index_name}
        else:
            raise ValueError("Invalid data source type")
        return source.query_get_variance(
            field=self.field_name,
            filters=self.filter_query or None,  # falsy filter is sent as None
            **target,
        )


class StddevMetric(FieldMetrics):
    """
    Field metric whose value comes from the data source's ``query_get_stddev``.

    SQL data sources are queried by table name and search-index data sources
    by index name; any other data source type is rejected.
    """

    def get_metric_identity(self):
        """Build the identity of this metric, typed as ``MetricsType.STDDEV``.

        Falsy table/index names are passed on as ``None``.
        """
        return MetricIdentity.generate_identity(
            metric_type=MetricsType.STDDEV,
            metric_name=self.name,
            data_source=self.data_source,
            field_name=self.field_name,
            table_name=self.table_name or None,
            index_name=self.index_name or None,
        )

    def _generate_metric_value(self):
        """Run ``query_get_stddev`` against the configured data source.

        Raises:
            ValueError: if the data source is neither a ``SQLDataSource``
                nor a ``SearchIndexDataSource``.
        """
        source = self.data_source
        # SQL sources are addressed by table, search sources by index.
        if isinstance(source, SQLDataSource):
            target = {"table": self.table_name}
        elif isinstance(source, SearchIndexDataSource):
            target = {"index_name": self.index_name}
        else:
            raise ValueError("Invalid data source type")
        return source.query_get_stddev(
            field=self.field_name,
            filters=self.filter_query or None,  # falsy filter is sent as None
            **target,
        )


class DuplicateCountMetric(FieldMetrics):
    """
    Field metric whose value comes from the data source's
    ``query_get_duplicate_count``.

    SQL data sources are queried by table name and search-index data sources
    by index name; any other data source type is rejected.
    """

    def get_metric_identity(self):
        """Build the identity of this metric, typed as
        ``MetricsType.DUPLICATE_COUNT``.

        Falsy table/index names are passed on as ``None``.
        """
        return MetricIdentity.generate_identity(
            metric_type=MetricsType.DUPLICATE_COUNT,
            metric_name=self.name,
            data_source=self.data_source,
            field_name=self.field_name,
            table_name=self.table_name or None,
            index_name=self.index_name or None,
        )

    def _generate_metric_value(self):
        """Run ``query_get_duplicate_count`` against the configured data source.

        Raises:
            ValueError: if the data source is neither a ``SQLDataSource``
                nor a ``SearchIndexDataSource``.
        """
        source = self.data_source
        # SQL sources are addressed by table, search sources by index.
        if isinstance(source, SQLDataSource):
            target = {"table": self.table_name}
        elif isinstance(source, SearchIndexDataSource):
            target = {"index_name": self.index_name}
        else:
            raise ValueError("Invalid data source type")
        return source.query_get_duplicate_count(
            field=self.field_name,
            filters=self.filter_query or None,  # falsy filter is sent as None
            **target,
        )


class NullCountMetric(FieldMetrics):
    """
    Field metric whose value comes from the data source's
    ``query_get_null_count``.

    SQL data sources are queried by table name and search-index data sources
    by index name; any other data source type is rejected.
    """

    def get_metric_identity(self):
        """Build the identity of this metric, typed as
        ``MetricsType.NULL_COUNT``.

        Falsy table/index names are passed on as ``None``.
        """
        return MetricIdentity.generate_identity(
            metric_type=MetricsType.NULL_COUNT,
            metric_name=self.name,
            data_source=self.data_source,
            field_name=self.field_name,
            table_name=self.table_name or None,
            index_name=self.index_name or None,
        )

    def _generate_metric_value(self):
        """Run ``query_get_null_count`` against the configured data source.

        Raises:
            ValueError: if the data source is neither a ``SQLDataSource``
                nor a ``SearchIndexDataSource``.
        """
        source = self.data_source
        # SQL sources are addressed by table, search sources by index.
        if isinstance(source, SQLDataSource):
            target = {"table": self.table_name}
        elif isinstance(source, SearchIndexDataSource):
            target = {"index_name": self.index_name}
        else:
            raise ValueError("Invalid data source type")
        return source.query_get_null_count(
            field=self.field_name,
            filters=self.filter_query or None,  # falsy filter is sent as None
            **target,
        )


class NullPercentageMetric(FieldMetrics):
    """
    Field metric whose value comes from the data source's
    ``query_get_null_percentage``.

    SQL data sources are queried by table name and search-index data sources
    by index name; any other data source type is rejected.
    """

    def get_metric_identity(self):
        """Build the identity of this metric, typed as
        ``MetricsType.NULL_PERCENTAGE``.

        Falsy table/index names are passed on as ``None``.
        """
        return MetricIdentity.generate_identity(
            metric_type=MetricsType.NULL_PERCENTAGE,
            metric_name=self.name,
            data_source=self.data_source,
            field_name=self.field_name,
            table_name=self.table_name or None,
            index_name=self.index_name or None,
        )

    def _generate_metric_value(self):
        """Run ``query_get_null_percentage`` against the configured data source.

        Raises:
            ValueError: if the data source is neither a ``SQLDataSource``
                nor a ``SearchIndexDataSource``.
        """
        source = self.data_source
        # SQL sources are addressed by table, search sources by index.
        if isinstance(source, SQLDataSource):
            target = {"table": self.table_name}
        elif isinstance(source, SearchIndexDataSource):
            target = {"index_name": self.index_name}
        else:
            raise ValueError("Invalid data source type")
        return source.query_get_null_percentage(
            field=self.field_name,
            filters=self.filter_query or None,  # falsy filter is sent as None
            **target,
        )


class DistinctCountMetric(FieldMetrics):
    """
    Field metric whose value comes from the data source's
    ``query_get_distinct_count``.

    SQL data sources are queried by table name and search-index data sources
    by index name; any other data source type is rejected.
    """

    def get_metric_identity(self):
        """Build the identity of this metric, typed as
        ``MetricsType.DISTINCT_COUNT``.

        Falsy table/index names are passed on as ``None``.
        """
        return MetricIdentity.generate_identity(
            metric_type=MetricsType.DISTINCT_COUNT,
            metric_name=self.name,
            data_source=self.data_source,
            field_name=self.field_name,
            table_name=self.table_name or None,
            index_name=self.index_name or None,
        )

    def _generate_metric_value(self):
        """Run ``query_get_distinct_count`` against the configured data source.

        Raises:
            ValueError: if the data source is neither a ``SQLDataSource``
                nor a ``SearchIndexDataSource``.
        """
        source = self.data_source
        # SQL sources are addressed by table, search sources by index.
        if isinstance(source, SQLDataSource):
            target = {"table": self.table_name}
        elif isinstance(source, SearchIndexDataSource):
            target = {"index_name": self.index_name}
        else:
            raise ValueError("Invalid data source type")
        return source.query_get_distinct_count(
            field=self.field_name,
            filters=self.filter_query or None,  # falsy filter is sent as None
            **target,
        )


class EmptyStringCountMetric(FieldMetrics):
    """
    Field metric for counting empty strings, computed by the data source's
    ``query_get_empty_string_count``.

    SQL data sources are queried by table name and search-index data sources
    by index name; any other data source type is rejected.
    """

    def get_metric_identity(self):
        """Build the identity of this metric, typed as
        ``MetricsType.EMPTY_STRING_COUNT``.

        Falsy table/index names are passed on as ``None``.
        """
        return MetricIdentity.generate_identity(
            metric_type=MetricsType.EMPTY_STRING_COUNT,
            metric_name=self.name,
            data_source=self.data_source,
            field_name=self.field_name,
            table_name=self.table_name or None,
            index_name=self.index_name or None,
        )

    def _generate_metric_value(self):
        """Run ``query_get_empty_string_count`` against the configured data source.

        Raises:
            ValueError: if the data source is neither a ``SQLDataSource``
                nor a ``SearchIndexDataSource``.
        """
        source = self.data_source
        # SQL sources are addressed by table, search sources by index.
        if isinstance(source, SQLDataSource):
            target = {"table": self.table_name}
        elif isinstance(source, SearchIndexDataSource):
            target = {"index_name": self.index_name}
        else:
            raise ValueError("Invalid data source type")
        return source.query_get_empty_string_count(
            field=self.field_name,
            filters=self.filter_query or None,  # falsy filter is sent as None
            **target,
        )


class EmptyStringPercentageMetric(FieldMetrics):
    """
    EmptyStringPercentageMetric is a class that represents a metric for the
    percentage of empty strings in a field, computed by the data source's
    ``query_get_empty_string_percentage``.

    SQL data sources are queried by table name and search-index data sources
    by index name; any other data source type is rejected.
    """

    def get_metric_identity(self):
        """Build the identity of this metric, typed as
        ``MetricsType.EMPTY_STRING_PERCENTAGE``.

        Falsy table/index names are normalised to ``None``. Every other field
        metric in this module normalises them the same way, so the same
        configuration yields identities of the same shape.
        """
        return MetricIdentity.generate_identity(
            metric_type=MetricsType.EMPTY_STRING_PERCENTAGE,
            metric_name=self.name,
            data_source=self.data_source,
            field_name=self.field_name,
            # Previously passed through raw, so "" leaked into the identity.
            table_name=self.table_name if self.table_name else None,
            index_name=self.index_name if self.index_name else None,
        )

    def _generate_metric_value(self):
        """Run ``query_get_empty_string_percentage`` against the data source.

        Raises:
            ValueError: if the data source is neither a ``SQLDataSource``
                nor a ``SearchIndexDataSource``.
        """
        if isinstance(self.data_source, SQLDataSource):
            return self.data_source.query_get_empty_string_percentage(
                table=self.table_name,
                field=self.field_name,
                filters=self.filter_query if self.filter_query else None,
            )
        elif isinstance(self.data_source, SearchIndexDataSource):
            return self.data_source.query_get_empty_string_percentage(
                index_name=self.index_name,
                field=self.field_name,
                filters=self.filter_query if self.filter_query else None,
            )
        else:
            raise ValueError("Invalid data source type")
